#include "kernel_operator.h"

using namespace AscendC;

// Queue configuration shared by both kernels below.
constexpr int32_t TQueDepth = 1;   // template depth argument of every TQue
constexpr int32_t BUFFER_NUM = 2;  // number of buffers InitBuffer allocates per queue

// Number of float elements that fit in one 32-byte block (32 / 4 == 8).
constexpr int num32 = 32 / sizeof(float);

// Specialized scatter_reduce kernel for float tensors whose dimensions before the scatter
// dimension collapse to one (beforeNum == 1), i.e. a [calcNum, afterNum] (N*K) layout
// scattered along the calcNum axis. Selected by tiling key 0 in scatter_reduce().
//
// Compute() hard-codes the reduction: the first value scattered to a position overwrites it
// and later ones keep the minimum (amin with include_self == false). The dim, reduce and
// include_self arguments are stored by Init() but not read anywhere in this class.
// NOTE(review): confirm the host tiling only selects this kernel for that combination.
//
// Work split: the afterNum columns are divided between cores in chunks rounded up to a
// multiple of num32 (8) columns; each core then loops over its columns num32 at a time,
// copying a calcNum x num32 tile of self / index / src into UB (one 32-byte row per element
// of the scatter dimension).
class ScatterReduceKernelSP {
private:
    TPipe* pipe;
    TQue<TPosition::VECIN, TQueDepth> inQueueSelf;
    TQue<TPosition::VECIN, TQueDepth> inQueueIndex;
    TQue<TPosition::VECIN, TQueDepth> inQueueSrc;
    // NOTE(review): Init() allocates buffers for this queue, but no tensor is ever allocated
    // from it; it only carries nowLocal (allocated from inQueueSelf) from Compute to CopyOut.
    TQue<TPosition::VECOUT, TQueDepth> outQueueY;
    // Scratch queue for numLocal (per-position hit counters).
    TQue<TPosition::VECCALC, TQueDepth> calcQueueNum;
    GlobalTensor<float> selfGm;
    GlobalTensor<int32_t> indexGm;
    GlobalTensor<float> srcGm;
    GlobalTensor<float> yGm;
    // Example: for an input of shape M*N*K scattered along dim 1 (the N axis):
    // number of elements in the dimensions before the selected one, i.e. M
    int beforeNum;
    // number of elements along the selected dimension, i.e. N
    int calcNum;
    // number of elements in the dimensions after the selected one (the GM stride between
    // consecutive elements along the selected dimension), i.e. K
    int afterNum;
    // total number of elements, i.e. M*N*K
    int totalNum;

    int dim;
    int reduce;
    int include_self;

    int blockIdx;
    int blockNum;
    // columns (out of afterNum) assigned to each core; a multiple of num32
    int blockCalcNum;
    // first column handled by this core
    int alreadyCalcNum;
    // columns actually handled by this core (smaller on the tail core, possibly 0)
    int thisBlockCalcNum;
    // number of num32-column chunks processed by this core
    int loopNum;
    // columns in the last chunk
    int lastLoopCalcNum;
    // columns in the chunk currently being processed
    int thisLoopCalcNum;

    // int beforeIdx;
    // int afterIdx;
    // int offsetX;
    LocalTensor<float> nowLocal;
    LocalTensor<int32_t> indexLocal;
    LocalTensor<float> srcLocal;
    LocalTensor<int32_t> numLocal;

public:
    __aicore__ inline ScatterReduceKernelSP() {}
    // Stores the tiling parameters, splits the afterNum columns across cores, and sets up the
    // GM views and UB queues. This core processes only its own share of columns.
    __aicore__ inline void Init(GM_ADDR self, GM_ADDR index, GM_ADDR src, GM_ADDR y, GM_ADDR workspace, TPipe* pipeIn, int beforeNum, int calcNum, int afterNum, int totalNum, int dim, int reduce, int include_self)
    {
        pipe = pipeIn;
        this->beforeNum = beforeNum;
        this->calcNum = calcNum;
        this->afterNum = afterNum;
        this->totalNum = totalNum;

        this->dim = dim;
        this->reduce = reduce;
        this->include_self = include_self;

        // printf("beforeNum: %d, calcNum: %d, afterNum: %d, totalNum: %d\n", beforeNum, calcNum, afterNum, totalNum);
        // printf("dim: %d, reduce: %d, include_self: %d\n", dim, reduce, include_self);
        
        blockIdx = GetBlockIdx();
        blockNum = GetBlockNum();

        // General case without the beforeNum == 1 assumption (disabled):
        /*
        // Work per core: blockCalcNum * blockNum >= beforeNum * afterNum
        // blockCalcNum >= beforeNum * afterNum / blockNum
        blockCalcNum = (beforeNum * afterNum + blockNum - 1) / blockNum;
        alreadyCalcNum = blockIdx * blockCalcNum;
        int remainCalcNum = beforeNum * afterNum - alreadyCalcNum;
        remainCalcNum = remainCalcNum >= 0 ? remainCalcNum : 0;
        // On the tail core: thisBlockCalcNum = beforeNum * afterNum - blockIdx * blockCalcNum;
        thisBlockCalcNum = blockCalcNum <= remainCalcNum ? blockCalcNum : remainCalcNum;
        // printf("blockIdx %d, blockNum %d, alreadyCalcNum %d, remainCalcNum %d, thisBlockCalcNum %d\n", blockIdx, blockNum, alreadyCalcNum, remainCalcNum, thisBlockCalcNum);
        // Idea: for double buffering, split thisBlockCalcNum into two halves (first rounded up,
        // second rounded down).
        // Rejected: to avoid the standalone CopyIn/CopyOut at the start and end of the pipeline,
        // just process one 32-byte-aligned chunk per iteration. Larger aligned chunks could be
        // benchmarked later.
        beforeIdx = alreadyCalcNum / afterNum;
        afterIdx = alreadyCalcNum % afterNum;
        offsetX = beforeIdx * calcNum * afterNum + afterIdx;
        */

        // Active path: assumes beforeNum == 1, so only the afterNum columns are split across cores.
        blockCalcNum = (afterNum + blockNum - 1) / blockNum;
        // Round up to a multiple of num32 so each core starts at a multiple of num32 columns (32 bytes).
        blockCalcNum = (blockCalcNum + num32 - 1) / num32 * num32;
        alreadyCalcNum = blockIdx * blockCalcNum;
        int remainCalcNum = afterNum - alreadyCalcNum;
        remainCalcNum = remainCalcNum >= 0 ? remainCalcNum : 0;
        thisBlockCalcNum = blockCalcNum <= remainCalcNum ? blockCalcNum : remainCalcNum;
        // With beforeNum == 1 the GM offset of this core's first column is simply
        // alreadyCalcNum (offsetX = afterIdx = alreadyCalcNum).
        loopNum = (thisBlockCalcNum + num32 - 1) / num32;
        // Only meaningful when loopNum > 0; Process() returns early otherwise.
        lastLoopCalcNum = thisBlockCalcNum - (loopNum - 1) * num32;
        // printf("blockIdx %d, blockNum %d, blockCalcNum %d, alreadyCalcNum %d, remainCalcNum %d, thisBlockCalcNum %d, loopNum %d, lastLoopCalcNum %d\n", blockIdx, blockNum, blockCalcNum, alreadyCalcNum, remainCalcNum, thisBlockCalcNum, loopNum, lastLoopCalcNum);

        selfGm.SetGlobalBuffer((__gm__ float *)self, totalNum);
        indexGm.SetGlobalBuffer((__gm__ int32_t *)index, totalNum);
        srcGm.SetGlobalBuffer((__gm__ float *)src, totalNum);
        yGm.SetGlobalBuffer((__gm__ float *)y, totalNum);

        // Bytes per UB buffer: calcNum rows of num32 four-byte elements (32 bytes per row).
        int loopCalcSize = calcNum * num32 * 4;

        pipe->InitBuffer(inQueueSelf, BUFFER_NUM, loopCalcSize);
        pipe->InitBuffer(inQueueIndex, BUFFER_NUM, loopCalcSize);
        pipe->InitBuffer(inQueueSrc, BUFFER_NUM, loopCalcSize);
        pipe->InitBuffer(outQueueY, BUFFER_NUM, loopCalcSize);
        pipe->InitBuffer(calcQueueNum, BUFFER_NUM, loopCalcSize);
    }
    // Runs CopyIn/Compute/CopyOut once per num32-column chunk of this core's columns.
    __aicore__ inline void Process()
    {
        if (loopNum == 0) {
            return;
        }
        // Each iteration handles thisLoopCalcNum columns (num32 for every chunk except the last,
        // which has lastLoopCalcNum); each column is calcNum elements.
        thisLoopCalcNum = num32;
        for (int i = 0; i < loopNum - 1; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
        thisLoopCalcNum = lastLoopCalcNum;
        CopyIn(loopNum - 1);
        Compute(loopNum - 1);
        CopyOut(loopNum - 1);
    }
    // Copies the calcNum x thisLoopCalcNum tile of chunk loopIdx of self, src and index from GM
    // into UB. Compute() addresses row i at offset i * num32, i.e. one 32-byte block per row.
    __aicore__ inline void CopyIn(int loopIdx)
    {
        nowLocal = inQueueSelf.AllocTensor<float>();
        indexLocal = inQueueIndex.AllocTensor<int32_t>();
        srcLocal = inQueueSrc.AllocTensor<float>();
        DataCopyPadExtParams<float> padParams{false, 0, 0, 0};
        // Strides on the VECIN/VECOUT side are in 32-byte units; all others are in bytes.
        // {blockCount = calcNum rows, blockLen = thisLoopCalcNum floats in bytes,
        //  srcStride = GM bytes skipped between rows, dstStride = 0, rsv = 0}
        DataCopyExtParams copyParams = {
            (uint16_t)calcNum, 
            (uint32_t)thisLoopCalcNum * 4, 
            (uint32_t)((afterNum - thisLoopCalcNum) * sizeof(float)), 
            0,
            0
        };
        DataCopyPad(nowLocal, selfGm[alreadyCalcNum + loopIdx * num32], copyParams, padParams);
        DataCopyPad(srcLocal, srcGm[alreadyCalcNum + loopIdx * num32], copyParams, padParams);
        DataCopyPadExtParams<int32_t> padParamsIndex{false, 0, 0, 0};
        // Same layout as above, for the int32 index tensor.
        DataCopyExtParams copyParamsIndex = {
            (uint16_t)calcNum, 
            (uint32_t)thisLoopCalcNum * 4, 
            (uint32_t)((afterNum - thisLoopCalcNum) * sizeof(int32_t)), 
            0,
            0
        };
        DataCopyPad(indexLocal, indexGm[alreadyCalcNum + loopIdx * num32], copyParamsIndex, padParamsIndex);
        inQueueSelf.EnQue(nowLocal);
        inQueueIndex.EnQue(indexLocal);
        inQueueSrc.EnQue(srcLocal);
    }
    // Scatter-reduces the src tile into the self tile in UB. For every row i and column j,
    // position (index[i][j], j) becomes src[i][j] on its first hit and min(current, src[i][j])
    // afterwards. Positions that receive nothing keep their self value.
    // NOTE(review): index values are not range-checked; anything outside [0, calcNum)
    // addresses memory outside these UB buffers.
    __aicore__ inline void Compute(int loopIdx)
    {
        nowLocal = inQueueSelf.DeQue<float>();
        indexLocal = inQueueIndex.DeQue<int32_t>();
        srcLocal = inQueueSrc.DeQue<float>();
        numLocal = calcQueueNum.AllocTensor<int32_t>();
        // for (int i = 0; i < calcNum * num32; i++) {
        //     numLocal(i) = 0;
        // }
        // Zero the per-position hit counters.
        // NOTE(review): Duplicate is a vector instruction, while the loop below reads and writes
        // numLocal with scalar accesses; confirm no explicit V->S synchronization is needed here.
        Duplicate(numLocal, 0, calcNum * num32);
        for (int i = 0; i < calcNum; i++) {
            for (int j = 0; j < thisLoopCalcNum; j++) {
                int idx1 = i * num32 + j;
                float src = srcLocal(idx1);
                int32_t index = indexLocal(idx1);
                int idx2 = index * num32 + j;
                float now = nowLocal(idx2);
                int32_t num = numLocal(idx2);
                if (num == 0) {
                    // first hit: overwrite self (include_self == false behaviour)
                    now = src;
                    numLocal(idx2) = num + 1;
                    nowLocal(idx2) = now;
                } else if (src < now) {
                    // amin
                    now = src;
                    nowLocal(idx2) = now;
                }
            }
        }
        // Hand the updated self tile to CopyOut via outQueueY; CopyOut frees it back to inQueueSelf.
        outQueueY.EnQue(nowLocal);
        inQueueIndex.FreeTensor(indexLocal);
        inQueueSrc.FreeTensor(srcLocal);
        calcQueueNum.FreeTensor(numLocal);
    }
    // Writes the updated tile of chunk loopIdx back to y using the same row stride as CopyIn.
    __aicore__ inline void CopyOut(int loopIdx)
    {
        nowLocal = outQueueY.DeQue<float>();
        // Strides on the VECIN/VECOUT side are in 32-byte units; all others are in bytes.
        // Here the source is UB (srcStride 0) and the destination GM skips
        // (afterNum - thisLoopCalcNum) floats between rows.
        DataCopyExtParams copyParams = {
            (uint16_t)calcNum, 
            (uint32_t)thisLoopCalcNum * 4, 
            0, 
            (uint32_t)((afterNum - thisLoopCalcNum) * sizeof(float)),
            0
        };
        DataCopyPad(yGm[alreadyCalcNum + loopIdx * num32], nowLocal, copyParams);
        inQueueSelf.FreeTensor(nowLocal);
    }
};

/**
 * Generic scatter_reduce kernel that processes one scatter "fiber" at a time.
 *
 * The tensors are viewed as [beforeNum, calcNum, afterNum] (e.g. M*N*K scattered along N).
 * A fiber is the calcNum elements that share one (beforeIdx, afterIdx) pair. Compute() only
 * writes positions inside the fiber it is working on, so fibers are independent. They are
 * distributed round-robin over the cores: core blockIdx handles fibers blockIdx,
 * blockIdx + blockNum, ... (exactly one fiber per core when blockNum == beforeNum * afterNum).
 *
 * Per fiber:
 *   CopyIn  - strided GM -> UB copy of self / index / src. Element i is addressed at offset
 *             i * num32 (T) or i * num32index (int32_t), i.e. one 32-byte UB block per element.
 *   Compute - scalar scatter-reduce into the UB copy of self.
 *   CopyOut - strided UB -> GM copy of the updated fiber into y.
 *
 * reduce: 0 = sum, 1 = prod, 2 = mean, 3 = amax, 4 = amin.
 * include_self == 0: the first value scattered to a position replaces self instead of being
 * combined with it, and mean does not count self.
 *
 * NOTE(review): self, index and src are assumed to share the same [beforeNum, calcNum,
 * afterNum] layout. Scalar arithmetic is done in float, so int32 values above 2^24 lose
 * precision. For integer T, mean is a float division converted back to T; confirm the
 * rounding matches the expected (torch floors integer means).
 */
template<typename T> class ScatterReduceKernel {
private:
    TPipe* pipe;
    TQue<TPosition::VECIN, TQueDepth> inQueueSelf;
    TQue<TPosition::VECIN, TQueDepth> inQueueIndex;
    TQue<TPosition::VECIN, TQueDepth> inQueueSrc;
    TQue<TPosition::VECOUT, TQueDepth> outQueueY;
    // Scratch queue for numLocal (per-position scatter counts).
    TQue<TPosition::VECCALC, TQueDepth> calcQueueNum;
    GlobalTensor<T> selfGm;
    GlobalTensor<int32_t> indexGm;
    GlobalTensor<T> srcGm;
    GlobalTensor<T> yGm;
    // Example: for an input of shape M*N*K scattered along dim 1 (the N axis):
    // number of elements in the dimensions before the selected one, i.e. M
    int beforeNum;
    // number of elements along the selected dimension, i.e. N
    int calcNum;
    // number of elements in the dimensions after the selected one (the GM stride between
    // consecutive elements of a fiber), i.e. K
    int afterNum;
    // total number of elements, i.e. M*N*K
    int totalNum;

    // how many T elements make up 32 bytes
    int num32;
    // how many int32_t elements make up 32 bytes
    int num32index;
    int dim;
    int reduce;
    int include_self;

    int blockIdx;
    int blockNum;
    // coordinates and GM element offset of the fiber currently being processed
    int beforeIdx;
    int afterIdx;
    int offsetX;
    LocalTensor<T> nowLocal;
    LocalTensor<int32_t> indexLocal;
    LocalTensor<T> srcLocal;
    LocalTensor<int32_t> numLocal;

public:
    __aicore__ inline ScatterReduceKernel() {}
    // Stores the tiling parameters and sets up the GM views and UB queues (one 32-byte UB
    // block per fiber element). Fiber selection happens in Process().
    __aicore__ inline void Init(GM_ADDR self, GM_ADDR index, GM_ADDR src, GM_ADDR y, GM_ADDR workspace, TPipe* pipeIn, int beforeNum, int calcNum, int afterNum, int totalNum, int dim, int reduce, int include_self)
    {
        pipe = pipeIn;
        this->beforeNum = beforeNum;
        this->calcNum = calcNum;
        this->afterNum = afterNum;
        this->totalNum = totalNum;

        this->dim = dim;
        this->reduce = reduce;
        this->include_self = include_self;
        num32 = 32 / sizeof(T);
        num32index = 32 / sizeof(int32_t);

        blockIdx = GetBlockIdx();
        blockNum = GetBlockNum();
        // The fiber to work on is chosen per iteration in Process() via SelectFiber().
        beforeIdx = 0;
        afterIdx = 0;
        offsetX = 0;

        selfGm.SetGlobalBuffer((__gm__ T *)self, totalNum);
        indexGm.SetGlobalBuffer((__gm__ int32_t *)index, totalNum);
        srcGm.SetGlobalBuffer((__gm__ T *)src, totalNum);
        yGm.SetGlobalBuffer((__gm__ T *)y, totalNum);

        pipe->InitBuffer(inQueueSelf, BUFFER_NUM, calcNum * 32);
        pipe->InitBuffer(inQueueIndex, BUFFER_NUM, calcNum * 32);
        pipe->InitBuffer(inQueueSrc, BUFFER_NUM, calcNum * 32);
        pipe->InitBuffer(outQueueY, BUFFER_NUM, calcNum * 32);
        pipe->InitBuffer(calcQueueNum, BUFFER_NUM, calcNum * 32);
    }
    // Processes every fiber assigned to this core. Cores whose blockIdx is beyond the fiber
    // count do nothing, instead of reading and writing past the end of the tensors.
    __aicore__ inline void Process()
    {
        if (beforeNum <= 0 || calcNum <= 0 || afterNum <= 0) {
            return;
        }
        const int fiberNum = beforeNum * afterNum;
        // Guard against a zero block count so the loop always advances.
        const int fiberStep = blockNum > 0 ? blockNum : 1;
        for (int fiber = blockIdx; fiber < fiberNum; fiber += fiberStep) {
            SelectFiber(fiber);
            CopyIn();
            Compute();
            CopyOut();
        }
    }
    // Copies the current fiber of self, src and index from GM (stride afterNum elements) into UB.
    __aicore__ inline void CopyIn()
    {
        nowLocal = inQueueSelf.AllocTensor<T>();
        indexLocal = inQueueIndex.AllocTensor<int32_t>();
        srcLocal = inQueueSrc.AllocTensor<T>();
        DataCopyPadExtParams<T> padParams{true, 0, 0, 0};
        // Strides on the VECIN/VECOUT side are in 32-byte units; all others are in bytes.
        // {blockCount = calcNum, blockLen = one element, srcStride = GM bytes skipped between
        //  elements of the fiber, dstStride = 0, rsv = 0}
        DataCopyExtParams copyParams = {
            (uint16_t)calcNum,
            (uint32_t)(1 * sizeof(T)),
            (uint32_t)((afterNum - 1) * sizeof(T)),
            0,
            0
        };
        DataCopyPad(nowLocal, selfGm[offsetX], copyParams, padParams);
        DataCopyPad(srcLocal, srcGm[offsetX], copyParams, padParams);
        DataCopyPadExtParams<int32_t> padParamsIndex{true, 0, 0, 0};
        DataCopyExtParams copyParamsIndex = {
            (uint16_t)calcNum,
            (uint32_t)(1 * sizeof(int32_t)),
            (uint32_t)((afterNum - 1) * sizeof(int32_t)),
            0,
            0
        };
        DataCopyPad(indexLocal, indexGm[offsetX], copyParamsIndex, padParamsIndex);
        inQueueSelf.EnQue(nowLocal);
        inQueueIndex.EnQue(indexLocal);
        inQueueSrc.EnQue(srcLocal);
    }
    // Scatter-reduces src into the UB copy of self for the current fiber:
    // self[index[i]] = reduce(self[index[i]], src[i]) for every i in [0, calcNum).
    __aicore__ inline void Compute()
    {
        nowLocal = inQueueSelf.DeQue<T>();
        indexLocal = inQueueIndex.DeQue<int32_t>();
        srcLocal = inQueueSrc.DeQue<T>();
        numLocal = calcQueueNum.AllocTensor<int32_t>();
        // numLocal(k * num32index) counts how many src values have been scattered to position k.
        for (int i = 0; i < calcNum; i++) {
            numLocal(i * num32index) = 0;
        }
        for (int i = 0; i < calcNum; i++) {
            int32_t index = indexLocal(i * num32index);
            // An index outside [0, calcNum) would address memory outside this fiber's UB
            // buffers; skip it instead of corrupting unrelated data.
            if (index < 0 || index >= calcNum) {
                continue;
            }
            float src = srcLocal(i * num32);
            float now = nowLocal(index * num32);
            int32_t num = numLocal(index * num32index);
            if (num == 0 && include_self == 0) {
                // first contribution and self excluded: start from src
                now = src;
            } else if (reduce == 0) {
                // sum
                now += src;
            } else if (reduce == 1) {
                // prod
                now *= src;
            } else if (reduce == 2) {
                // mean: accumulate the sum here, divide after the loop
                now += src;
            } else if (reduce == 3) {
                // amax
                if (src > now) {
                    now = src;
                }
            } else if (reduce == 4) {
                // amin
                if (src < now) {
                    now = src;
                }
            }
            numLocal(index * num32index) = num + 1;
            nowLocal(index * num32) = now;
        }
        if (reduce == 2) {
            // mean: when include_self is set, torch.Tensor.scatter_reduce counts the original
            // self value as one extra element, so the divisor is (scattered count + 1).
            // Positions that received nothing are left untouched.
            const int32_t selfCount = include_self != 0 ? 1 : 0;
            for (int i = 0; i < calcNum; i++) {
                int32_t num = numLocal(i * num32index);
                if (num != 0) {
                    float now = nowLocal(i * num32);
                    now /= static_cast<float>(num + selfCount);
                    nowLocal(i * num32) = now;
                }
            }
        }
        inQueueIndex.FreeTensor(indexLocal);
        inQueueSrc.FreeTensor(srcLocal);
        // numLocal is scratch for this fiber only; release it so the next fiber can reuse it.
        calcQueueNum.FreeTensor(numLocal);
        // nowLocal was allocated from inQueueSelf; it travels to CopyOut through outQueueY and
        // is freed back to inQueueSelf there.
        outQueueY.EnQue(nowLocal);
    }
    // Writes the updated fiber back to y with the same GM stride used by CopyIn.
    __aicore__ inline void CopyOut()
    {
        nowLocal = outQueueY.DeQue<T>();
        // Strides on the VECIN/VECOUT side are in 32-byte units; all others are in bytes.
        DataCopyExtParams copyParams = {
            (uint16_t)calcNum,
            (uint32_t)(1 * sizeof(T)),
            0,
            (uint32_t)((afterNum - 1) * sizeof(T)),
            0
        };
        DataCopyPad(yGm[offsetX], nowLocal, copyParams);
        inQueueSelf.FreeTensor(nowLocal);
    }

private:
    // Points beforeIdx / afterIdx / offsetX at fiber `fiber` (0 <= fiber < beforeNum * afterNum).
    __aicore__ inline void SelectFiber(int fiber)
    {
        beforeIdx = fiber / afterNum;
        afterIdx = fiber % afterNum;
        offsetX = beforeIdx * calcNum * afterNum + afterIdx;
    }
};
    

/**
 * Kernel entry point for scatter_reduce.
 *
 * Reads the tiling data and dispatches on the tiling key:
 *   key 0 -> ScatterReduceKernelSP (float-only specialized kernel)
 *   key 1 -> ScatterReduceKernel<DTYPE_SELF> (generic kernel)
 * Any other key launches nothing.
 */
extern "C" __global__ __aicore__ void scatter_reduce(GM_ADDR self, GM_ADDR index, GM_ADDR src, GM_ADDR y,
                                                     GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tiling_data, tiling);
    TPipe pipe;
    if (TILING_KEY_IS(0)) {
        ScatterReduceKernelSP op;
        op.Init(self, index, src, y, workspace, &pipe,
                tiling_data.beforeNum, tiling_data.calcNum, tiling_data.afterNum, tiling_data.totalNum,
                tiling_data.dim, tiling_data.reduce, tiling_data.include_self);
        op.Process();
    } else if (TILING_KEY_IS(1)) {
        ScatterReduceKernel<DTYPE_SELF> op;
        op.Init(self, index, src, y, workspace, &pipe,
                tiling_data.beforeNum, tiling_data.calcNum, tiling_data.afterNum, tiling_data.totalNum,
                tiling_data.dim, tiling_data.reduce, tiling_data.include_self);
        op.Process();
    }
}